#!/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
import json
import numpy as np
from mpl_toolkits.basemap import Basemap
import matplotlib.cm as cm
from scipy import stats

#----------
# Parameters
#----------
# Number of regional clusters; one row of panels is drawn per cluster.
ncluster = 3

# Map extent passed to Basemap: [lon_min, lon_max, lat_min, lat_max].
ddomain = [123, 146, 24, 46]

# Marker colour per TC category, keyed by the values of the track 'cate' column.
tcols = dict(
    TD="gray",
    TS="black",
    C1="blue",
    C2="green",
    C3="orange",
    C4="purple",
    C5="red",
)

# Marker glyph per TC category (wrapped in $...$ and rendered as mathtext).
tsymbols = dict(
    TD="D",
    TS="0",
    C1="1",
    C2="2",
    C3="3",
    C4="4",
    C5="5",
)

# Panel labels: (a), (b), (c), ...
caps = "abcdefghijklmn"

#----------
# Subroutine
#----------
def linreg(X,Y):
    """
    Ordinary least-squares fit of the straight line y = a*x + b.

    Input:
        X, Y: equal-length sequences of numbers (lists, numpy arrays and
              pandas Series are all accepted)

    Output:
        X:   the input X object, returned unchanged (convenient for plotting)
        YY2: numpy array of fitted values a*X + b at each X
        aa:  the slope a of the fitted line (the intercept is not returned)

    Raises:
        ValueError: if X and Y differ in length, fewer than two points are
                    given, or all X values are identical (slope undefined).
    """
    xs = np.asarray(X, dtype=float)
    ys = np.asarray(Y, dtype=float)
    if xs.shape != ys.shape:
        raise ValueError("linreg: X and Y must have the same length")
    if xs.size < 2:
        raise ValueError("linreg: at least two points are required")

    # Work with centred data: algebraically identical to the raw normal
    # equations, but avoids cancellation when X is large (e.g. calendar
    # years) compared with its spread.
    x_mean = xs.mean()
    y_mean = ys.mean()
    dx = xs - x_mean
    sxx = float(np.dot(dx, dx))
    if sxx == 0.0:
        raise ValueError("linreg: all X values are identical; slope is undefined")

    aa = float(np.dot(dx, ys - y_mean)) / sxx
    bb = y_mean - aa * x_mean
    YY2 = aa * xs + bb
    return X, YY2, aa

def mk_test(x, alpha = 0.05):
    """
    Mann-Kendall test for the presence of a monotonic trend in a series.

    Input:
        x:     a 1-D vector of data, in time order
        alpha: significance level of the two-sided test

    Output:
        h: True if a trend is present at level alpha, otherwise False
        p: p-value of the two-sided significance test

    Raises:
        ValueError: if x contains NaN values.

    Examples
    --------
      >>> x = np.random.rand(100)
      >>> h, p = mk_test(x, 0.05)
    """
    from scipy.stats import norm

    # Positional float array: avoids label-based indexing on pandas input.
    x = np.asarray(x, dtype=float)
    if np.isnan(x).any():
        raise ValueError("mk_test: x contains NaN values")
    n = len(x)

    # S = sum over all pairs k < j of sign(x[j] - x[k])
    k_idx, j_idx = np.triu_indices(n, k=1)
    s = np.sign(x[j_idx] - x[k_idx]).sum()

    # Var(S), with the tie correction SUBTRACTED for each group of tied
    # values of size t: [n(n-1)(2n+5) - sum t(t-1)(2t+5)] / 18.
    # Untied values have t = 1 and contribute nothing.
    _, tp = np.unique(x, return_counts=True)
    var_s = (n*(n-1)*(2*n+5) - np.sum(tp*(tp-1)*(2*tp+5))) / 18.0

    # Continuity-corrected standard normal statistic. var_s > 0 whenever
    # s != 0, because s can only be non-zero if the values are not all tied.
    if s > 0:
        z = (s - 1)/np.sqrt(var_s)
    elif s < 0:
        z = (s + 1)/np.sqrt(var_s)
    else:
        z = 0.0

    # calculate the p_value
    p = 2*(1-norm.cdf(abs(z))) # two tail test
    h = abs(z) > norm.ppf(1-alpha/2)

    return h, p

#----------
# Read Data
#----------
indir1 = "./Data/Observations"
indir2 = "./Data/HiFLOR_AllForc"

precip={}
eqpt ={}   # NOTE(review): not filled or used anywhere in this script

tlons={}
tlats={}
tcats={}
tcinfo={}
tcnum={}
tcnum_all={}

qsf = {}
cf = {}
wf = {}

#--Frequency time series (Obs and HiFLOR AllForc). These file paths do not
#  depend on the cluster, so each file is read once. The same DataFrame is
#  then stored for every cluster below; it is only read afterwards.
datafl = "%s/timeseries_each.json" % (indir1)
tcnum_obs_ = pd.read_json(datafl)
datafl2 = "%s/timeseries_each.json" % (indir2)
print (datafl2)
tcnum_all_ = pd.read_json(datafl2)

for cl in range(1,ncluster+1,1):
    print ("read data cl=",cl)

    #--Precip data: composite 'rain' field on the lon/lat grid found in the file
    infl_precip = "%s/data_composite_precip_cl%2.2i.json" % (indir1,cl)
    print (infl_precip)
    df_ = pd.read_json(infl_precip)
    lons = df_.lon.unique()
    lats = df_.lat.unique()
    flon,flat = np.meshgrid(lons,lats) # 2-D grids of shape (len(lats), len(lons))
    jmax,imax = np.shape(flon)
    # Assumes rows are ordered lat-major with lon varying fastest -- TODO
    # confirm; any other ordering would silently scramble the field here.
    precip[cl] = df_['rain'].to_numpy().reshape((jmax,imax))

    #--TC data: summary info (dict used for the panel title) and track points
    infl_tc_info = "%s/data_composite_tc_info_cl%2.2i.json" % (indir1,cl)
    with open(infl_tc_info) as json_fl:
        tcinfo[cl] = json.load(json_fl)

    infl_tc_track = "%s/data_composite_tc_track_cl%2.2i.json" % (indir1,cl)
    df_ = pd.read_json(infl_tc_track)
    tlons[cl] = df_['lon'].to_numpy()
    tlats[cl] = df_['lat'].to_numpy()
    tcats[cl] = df_['cate'].to_numpy()

    #--Frequency time series: shared DataFrames read once above
    tcnum[cl] = tcnum_obs_
    tcnum_all[cl] = tcnum_all_

    #--Front data: keep rows with dist > 500 (units not visible here) and split
    #  them by iflag into qsf (1), cf (2) and wf (3)
    INF="%s/data_composite_front_track_cl%2.2i.json" % (indir1,cl)
    print (INF)
    df_ = pd.read_json(INF)
    far_ = df_["dist"] > 500
    qsf[cl] = df_.loc[far_ & (df_['iflag']==1)]
    cf[cl] = df_.loc[far_ & (df_['iflag']==2)]
    wf[cl] = df_.loc[far_ & (df_['iflag']==3)]

#----------
# Plot
#----------
fig = plt.figure(figsize=(13,13)) # set figure environment

regname={
1: "Eastern Japan",
2: "Central Japan",
3: "Western Japan",
}

# Period covered by the HiFLOR (AllForc) ensemble time series.
syear=1977
eyear=2015
# Significance level for the Mann-Kendall test. A linear trend line is drawn
# only when the test's p-value is at or below this level.
sig_level = 0.05

for cl in range(1,ncluster+1,1):

    #----------------
    # Map (left column)
    #----------------
    ax1 = plt.subplot2grid((3,3),(cl-1,0), colspan=1)

    m = Basemap(projection='mill',llcrnrlat=ddomain[2], urcrnrlat=ddomain[3], llcrnrlon=ddomain[0],urcrnrlon=ddomain[1],
                 lat_ts=10, resolution='l') # draw basemap
    m.drawcoastlines() # draw coast line
    meridians = np.arange(0.,360.,5.)
    m.drawmeridians(meridians,labels=[0,0,0,1],linewidth=0, fontsize=10)
    parallels = np.arange(0.,90, 5.)
    m.drawparallels(parallels,labels=[1,0,0,1],linewidth=0, fontsize=10)

    #--TC points: one marker per track point, glyph and colour chosen by category
    for tlon,tlat,tcat in zip(tlons[cl],tlats[cl],tcats[cl]):
        xx2,yy2 = m(tlon,tlat)
        m.plot(xx2,yy2, marker=r'$%s$' % (tsymbols[tcat]), color=tcols[tcat], ms=6, lw=4)

    #--quasi-stationary fronts: each row of qsf is drawn as one faint line
    for qlon, qlat in zip(qsf[cl]['lon'].to_numpy(), qsf[cl]['lat'].to_numpy()):
        xx2,yy2 = m(qlon,qlat)
        m.plot(xx2, yy2, color='black', lw=1, alpha=0.2)

    #--precip contour
    contours=np.arange(0, 30.0, 2) # set contours
    cmap=cm.Blues # set colormap
    x,y=m(flon,flat) # lon,lat => x,y
    cs = m.contourf(x,y,precip[cl], contours, cmap=cmap, extend="both") # plot contours

    #----------------
    # Time Series (right two columns)
    #----------------
    # colspan=2 exactly fills columns 1-2 of the 3-column grid.
    ax2 = plt.subplot2grid((3,3),(cl-1,1), colspan=2)

    #--observed
    years = tcnum[cl].year
    alldata = tcnum[cl]['r%i' %(cl)].to_numpy()
    _, pval1 = mk_test(alldata, alpha=sig_level)
    ax2.plot(years, alldata, "o-", ms=6, lw=2, color="k", label="Observed (P-value=%.2f)" % (pval1), zorder=99)
    if pval1 <= sig_level:
        xx, yy, aa = linreg(years,alldata)
        print ("obs cl=",cl,"trend=",aa)
        ax2.plot(xx, yy, "--", color="black", lw=3)
    ax2.set_ylabel("Anomalous Days", fontsize=12)

    #--HiFLOR (AllForc): ensemble min-max envelope and ensemble mean
    all_ = tcnum_all[cl]
    sel_ = (all_["cl"]==cl) & (all_["year"]>=syear) & (all_["year"]<=eyear)
    data1_ = all_.loc[sel_, 'freq'].to_numpy()
    years1 = np.arange(syear,eyear+1,1)
    ymax1 = len(years1)
    if data1_.size == 0 or data1_.size % ymax1 != 0:
        raise ValueError(
            "cluster %i: expected a whole number of %i-year HiFLOR members, got %i values"
            % (cl, ymax1, data1_.size))
    # (member, year) array. Assumes each member's years are stored contiguously
    # and in ascending order -- TODO confirm against the data file.
    allens1 = data1_.reshape((-1, ymax1))
    ensmean1 = allens1.mean(axis=0)

    ax2.fill_between(years1, allens1.min(axis=0), allens1.max(axis=0), color="red", alpha=0.2)
    _, pval2 = mk_test(ensmean1, alpha=sig_level)
    ax2.plot(years1, ensmean1, "o-", color="red", label="HiFLOR (P-Value=%.2f)" % (pval2), alpha=0.6)
    if pval2 <= sig_level:
        xx1, yy1, aa1 = linreg(years1, ensmean1)
        ax2.plot(xx1, yy1, "--", color="red",lw=3)

    #----------------
    # Legend
    #----------------
    ax2.legend(loc=2, prop={"size":10})

    #----------------
    # Title
    #----------------
    tcinfo[cl]["regname"] = regname[cl]
    title = r"%(regname)s (Total:%(ddays)i days, TCR:%(tcer).1f%%, MI:%(meanws).1f $\rm m\ s^{-1}$, MLMI:%(maxws).1f $\rm m\ s^{-1}$)" % tcinfo[cl]
    ax2.text(0.25,1.03, "(%s) %s" % (caps[cl-1],title), horizontalalignment='center',fontsize=16, transform=ax2.transAxes)

    ax2.tick_params(axis="both", labelsize=12)

# Shared precipitation colorbar under the map column. Every map panel uses the
# same contour levels and colormap, so the last panel's contour set stands in for all.
cax = fig.add_axes([0.125, 0.10, 0.225, 0.005])

art = fig.colorbar(cs, cax=cax, orientation='horizontal')
art.set_label(r'Anomaly of 5-day mean daily precipitation [$\rm mm\ day^{-1}$]', fontsize=12, x=0.85)
art.ax.tick_params(labelsize=14)

fig.savefig("./Fig6.png")
